{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cb526c3a-a206-4293-91d5-cd9fde76c274",
   "metadata": {},
   "source": [
    "自定义输出规则，让LLM不只是进行文本聊天，还可以与现实中的各种系统无缝对接"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e707d792-81da-4b5f-9c63-4cb7d2ebbc9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 讲笑话机器人：希望每次根据指令，可以输出一个这样的笑话(小明是怎么死的？笨死的)\n",
    "from langchain.llms import OpenAI\n",
    "from langchain.output_parsers import PydanticOutputParser\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.pydantic_v1 import BaseModel,Field,validator\n",
    "from typing import List\n",
    "import os\n",
    "\n",
    "# Endpoint and credentials come from environment variables, never hard-coded\n",
    "api_base, api_key = os.environ.get(\"OPENAI_PROXY\"), os.environ.get(\"OPENAI_API_KEY\")\n",
    "\n",
    "# Build the LLM client (temperature pinned to 0)\n",
    "model = OpenAI(\n",
    "    openai_api_base=api_base,\n",
    "    openai_api_key=api_key,\n",
    "    model=\"gpt-3.5-turbo-instruct\",\n",
    "    temperature=0,\n",
    ")\n",
    "\n",
    "# Data model describing the structure of the final parsed result.\n",
    "# NOTE(review): documented with comments rather than a class docstring on purpose;\n",
    "# a docstring would presumably end up in the schema handed to the parser - verify before adding one.\n",
    "class Joke(BaseModel):\n",
    "    # The Field descriptions are runtime strings (the model is passed to the output parser), keep them as-is\n",
    "    setup:str = Field(description=\"设置笑话的问题\")\n",
    "    punchline:str = Field(description=\"回答笑话的答案\")\n",
    "\n",
    "    # Validate that the question has the expected form: it must end with a full-width question mark.\n",
    "    # Returns the value unchanged when valid; raises ValueError otherwise.\n",
    "    @validator(\"setup\")\n",
    "    def question_mark(cls, field):\n",
    "        # endswith() is simply False for an empty string, so an empty setup is rejected with the\n",
    "        # same ValueError as any other malformed question instead of raising IndexError from field[-1]\n",
    "        if not field.endswith(\"？\"):\n",
    "            raise ValueError(\"不符合预期的问题格式!\")\n",
    "        return field\n",
    "\n",
    "# Build the output parser from the Joke data model\n",
    "parser = PydanticOutputParser(pydantic_object=Joke)\n",
    "\n",
    "# format_instructions is pre-filled from the parser; only query is supplied at call time\n",
    "prompt = PromptTemplate(\n",
    "    input_variables=[\"query\"],\n",
    "    partial_variables={\"format_instructions\": parser.get_format_instructions()},\n",
    "    template=\"回答用户的输入.\\n{format_instructions}\\n{query}\\n\",\n",
    ")\n",
    "\n",
    "# Pipe the prompt into the LLM and run it on one request\n",
    "prompt_and_model = prompt | model\n",
    "out_put = prompt_and_model.invoke({\"query\": \"给我讲一个笑话\"})\n",
    "\n",
    "# Show the raw completion first, then hand it to the parser (its result is the cell output)\n",
    "print(\"out_put:\", out_put)\n",
    "parser.invoke(out_put)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1ad5294-52c1-4cb9-bfd3-02624ab6124c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM的输出格式化成python list形式，类似['a','b','c']\n",
    "from langchain.output_parsers import CommaSeparatedListOutputParser\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.llms import OpenAI\n",
    "import os\n",
    "\n",
    "# Read the API key and the endpoint override from environment variables\n",
    "api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
    "api_base = os.environ.get(\"OPENAI_PROXY\")\n",
    "\n",
    "# Construct the LLM\n",
    "model = OpenAI(\n",
    "    temperature=0,\n",
    "    model=\"gpt-3.5-turbo-instruct\",\n",
    "    openai_api_base=api_base,\n",
    "    openai_api_key=api_key,\n",
    ")\n",
    "\n",
    "# Output parser for comma-separated list replies\n",
    "parser = CommaSeparatedListOutputParser()\n",
    "\n",
    "# format_instructions is pre-filled from the parser; only subject is supplied per call\n",
    "prompt = PromptTemplate(\n",
    "    template = \"列出5个{subject}.\\n{format_instructions}\",\n",
    "    input_variables = [\"subject\"],\n",
    "    partial_variables = {\"format_instructions\":parser.get_format_instructions()}\n",
    ")\n",
    "\n",
    "_input = prompt.format(subject=\"常见的小狗的名字\")\n",
    "# Use .invoke() rather than calling the model object directly: the direct-call form is\n",
    "# deprecated in LangChain, and .invoke() matches how this notebook runs its other chain\n",
    "output = model.invoke(_input)\n",
    "print(output)\n",
    "# Parse the raw completion text with the list parser (its result is the cell output)\n",
    "parser.parse(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (langchain)",
   "language": "python",
   "name": "langchain-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
